#!/usr/bin/env python3

# Python script to run and analyse MMS test
#
# Exits with status 1 if any test failed


from __future__ import division
from __future__ import print_function
from builtins import str

import numpy as np

#Requires: boutcore
#requires: not make

import boutcore as bc

# Initialise boutcore; this runs before any option is set or any mesh is
# created further down.
# NOTE(review): the repeated "-q" flags presumably lower the output
# verbosity -- confirm against the BOUT++ command-line options.
bc.init("-q -q -q")


# Test cases, one [in1, in2, outfunc] triple per entry.  The driver loop
# below creates a field from each expression and compares
# bc.bracket(in1, in2) against outfunc.
funcs = [
    ['sin(z)', 'sin(4*x)', '4*cos(z)*cos(4*x)']
]

# List of NX values to use
# Each value is passed to genMesh() as both nx and nz, and only the last
# two entries are used for the convergence estimate.
nlist = [8, 16, 32]

# Allowed deviation of the measured convergence order from the expected
# order of a scheme; the comparison is strict on both sides.
prec = 0.2


def genMesh(nx, ny, nz, **kwargs):
    """Write the mesh options for the requested sizes and build a mesh.

    All options are written with ``force=True``.  The fixed mesh options
    are written first, followed by every keyword argument as an
    ``option name -> value`` pair, so callers can add or override
    settings, e.g. ``genMesh(8, 1, 8, **{"mesh:ddx:upwind": "U1"})``.

    Args:
        nx: Size in x.  When greater than one, ``mesh:mxg`` is set to 2
            and ``mesh:nx`` to ``nx + 4`` (presumably two guard cells on
            each side -- confirm); otherwise ``mesh:mxg`` is 0 and
            ``mesh:nx`` is ``nx``.  ``mesh:dx`` is ``2*pi/nx``.
        ny: Value written to ``mesh:ny``; ``mesh:myg`` is 2 when
            ``ny > 1`` and 0 otherwise.  ``mesh:dy`` is ``1/ny``.
        nz: Value written to both ``mz`` and ``mesh:nz``.  ``mesh:dz``
            is ``1/nz``.
        **kwargs: Additional options, handed to ``bc.setOption``
            unchanged.

    Returns:
        The ``bc.Mesh`` constructed from the ``"mesh"`` section.
    """
    mxg = 2 if nx > 1 else 0
    myg = 2 if ny > 1 else 0
    full_nx = nx + 4 if nx > 1 else nx

    # Kept in a list so the options are written in a fixed order.
    fixed_options = [
        ("mesh:mxg", str(mxg)),
        ("mesh:myg", str(myg)),
        ("mesh:nx", str(full_nx)),
        ("mesh:ny", str(ny)),
        ("mz", str(nz)),
        ("mesh:nz", str(nz)),
        ("mesh:dx", "2*pi/(%d)" % nx),
        ("mesh:dy", "1/(%d)" % ny),
        ("mesh:dz", "1/(%d)" % nz),
    ]
    for option, value in fixed_options:
        bc.setOption(option, value, force=True)

    # Caller-supplied options go last so they can override the ones above.
    for option, value in kwargs.items():
        bc.setOption(option, value, force=True)

    return bc.Mesh(section="mesh")


# Driver: for every test case and every bracket scheme, evaluate the
# bracket on the two finest grids, estimate the convergence order from
# the RMS error, and record a message for each scheme whose order is not
# within `prec` of the expected one.

# One human-readable message per failed (test case, scheme) combination.
errlist = []

# Schemes under test, one [method, expected order, options] entry each.
# The options are handed to genMesh(), which writes them with
# force=True.
# NOTE(review): nothing visible here resets those options afterwards, so
# an entry with empty options placed after one that sets them would
# inherit the earlier values -- keep that in mind when reordering.
brackets = [["ARAKAWA", 2, {}],
            ["STD", 1, {"mesh:ddx:upwind": "U1",
                        "mesh:ddz:upwind": "U1"}
             ],
            ["STD", 2, {"mesh:ddx:upwind": "C2",
                        "mesh:ddz:upwind": "C2"}
             ],
            # Weno convergence order is currently under discussion:
            # https://github.com/boutproject/BOUT-dev/issues/1049
            # ["STD", 3 , {"mesh:ddx:upwind":"W3",
            #              "mesh:ddz:upwind":"W3"}
            #  ]
            ]

# Only the last two entries of nlist enter the convergence estimate.
resolutions = nlist[-2:]

# Part of the error field that is compared: first index without its two
# outermost points on either side (presumably the x guard cells set up
# by genMesh), second index 0 (the meshes are built with ny = 1), and the
# full range of the third index.
slize = [slice(2, -2), 0, slice(None)]

for in1, in2, outfunc in funcs:
    for name, order, args in brackets:
        errors = []
        for n in resolutions:
            mesh = genMesh(n, 1, n, **args)
            inf1 = bc.create3D(in1, mesh)
            inf2 = bc.create3D(in2, mesh)
            inf = bc.bracket(inf1, inf2, method=name)
            outf = bc.create3D(outfunc, mesh)
            diff = (inf - outf)[slize]
            # Root-mean-square error over the selected points.
            errors.append(np.sqrt(np.mean(diff**2)))

        # Observed order: log of the error ratio over log of the
        # resolution ratio (coarse -> fine).
        errc = np.log(errors[0] / errors[1])
        difc = np.log(resolutions[1] / resolutions[0])
        conv = errc / difc

        # Written as "not (a < conv < b)" so that a NaN order is reported
        # as a failure as well.
        if not (order - prec < conv < order + prec):
            # Name the scheme and its options: several entries share the
            # same expected order, so the order alone does not identify
            # which one failed.
            errlist.append(
                "%s %s: {%s,%s} -> %s failed (conv: %f - expected: %f)" % (
                    name, args, in1, in2, outfunc, conv, order))

# Report failures before shutting the library down, so the messages are
# emitted even if the shutdown itself misbehaves.
if errlist:
    print("\n".join(errlist))

# FIXME: this is a workaround for a strange bug, MPI_Comm_free after
# MPI_Finalize. Seems to be from the global mesh getting destroyed
# after/at the same time as finalise is called -- hence explicitly
# deleting it here
# The shutdown is performed on the failure path too, so the library is
# finalised regardless of the test outcome.
mesh = bc.Mesh.getGlobal()
del mesh
bc.finalise()

# SystemExit is raised directly: the exit() helper is injected by the
# site module and is not guaranteed to exist in every interpreter setup.
if errlist:
    raise SystemExit(1)

print("All tests passed")
raise SystemExit(0)
